""" Diffusion Models """

from functools import partial
from typing import Any, Tuple, List, Dict, Union, Type, Optional, Callable

import jax
import jax.numpy as jnp
from jax import nn
import haiku as hk
import einops

from sb3_jax.common.jax_utils import jax_print
from sb3_jax.du.models import DDPMCoefficients
from diffgro.common.models.utils import (
    act2fn,
    init_he_uniform,
    init_he_normal,
)
from diffgro.common.models.helpers import (
    TimeEmb,
    DownSample1D,
    UpSample1D,
    Conv1DBlock,
    Residual,
    PreNorm,
)


# =================================== MLP-Diffusion ================================ #


class MLPDiffusion(hk.Module):
    """MLP denoiser that re-injects the conditioning vector at every layer.

    Each entry of ``batch_keys`` and the diffusion timestep are embedded to
    ``emb_dim`` features and concatenated into one conditioning vector. That
    vector is appended to the input of every hidden layer and of the output
    layer.

    Args:
        emb_dim: width of each condition / timestep / ``x_t`` embedding.
        out_dim: size of the network output (also exposed as ``noise_dim``).
        net_arch: hidden layer widths; must contain at least one entry. Widths
            after the first are added residually to the previous layer's
            output, so they must match it.
        batch_keys: keys of ``batch_dict`` used as conditions.
        activation_fn: key into ``act2fn``.
    """

    def __init__(
        self,
        emb_dim: int,
        out_dim: int,
        net_arch: List[int],
        batch_keys: List[str],
        activation_fn: str = "mish",
    ):
        super().__init__()
        self.emb_dim = emb_dim
        self.out_dim = out_dim
        self.noise_dim = out_dim
        self.net_arch = net_arch
        self.batch_keys = batch_keys
        self.activation_fn = activation_fn

    def __call__(
        self, x_t: jax.Array, batch_dict: Dict[str, jax.Array], t: jax.Array
    ) -> jax.Array:
        """Predict the denoising target for ``x_t``.

        x_t : [batch_size, dim]
        t   : [batch_size, 1]
        """
        act = act2fn[self.activation_fn]

        def two_layer_embed(inp, first_name, second_name):
            # Linear(2 * emb_dim) -> activation -> Linear(emb_dim)
            hidden = hk.Linear(self.emb_dim * 2, name=first_name, **init_he_normal())(
                inp
            )
            return hk.Linear(self.emb_dim, name=second_name, **init_he_normal())(
                act(hidden)
            )

        # conditioning vector: one embedding per batch key, then the timestep
        cond_parts = [
            two_layer_embed(batch_dict[key], f"mlp_{key}_emb_0", f"mlp_{key}_emb_1")
            for key in self.batch_keys
        ]
        cond_parts.append(TimeEmb(self.emb_dim, self.activation_fn)(t))
        cond = jnp.concatenate(cond_parts, axis=-1)

        x_emb = two_layer_embed(x_t, "mlp_xt_emb_0", "mlp_xt_emb_1")

        first_in = jnp.concatenate((x_emb, cond), axis=-1)
        hidden = act(hk.Linear(self.net_arch[0], name="mlp_0", **init_he_normal())(first_in))

        # residual hidden layers; the 1/1.414 factor scales the carried
        # activation down before it is both fed forward and added back
        for layer_idx, width in enumerate(self.net_arch[1:], start=1):
            scaled = hidden / 1.414
            layer_in = jnp.concatenate((scaled, cond), axis=-1)
            pre_act = hk.Linear(width, name=f"mlp_{layer_idx}", **init_he_normal())(
                layer_in
            )
            hidden = act(pre_act) + scaled

        final_in = jnp.concatenate((hidden, cond), axis=-1)
        return hk.Linear(self.out_dim, name="last", **init_he_normal())(final_in)


# =================================== UNet-Diffusion ================================ #


class ResidualTemporalBlock(hk.Module):
    """Two ``Conv1DBlock``s with an additive embedding bias and a skip path.

    The embedding ``t`` is activated, projected to ``out_channels`` and
    broadcast over the horizon axis after the first convolution block. The
    skip path is the identity when the channel count is unchanged, otherwise a
    width-1 ``NCW`` convolution.
    """

    def __init__(
        self, out_channels: int, kernel_size: int = 5, activation_fn: str = "mish"
    ):
        super().__init__()
        self.activation_fn = activation_fn
        self.kernel_size = kernel_size
        self.out_channels = out_channels

    def __call__(self, x: jax.Array, t: jax.Array):
        """
        x : [batch_size, inp_channels, horizon]
        t : [batch_size, emb_dim]
        returns : [batch_size, out_channels, horizon]
        """
        # One bias per output channel, with a trailing length-1 axis so it
        # broadcasts along the horizon.
        bias = hk.Linear(self.out_channels)(act2fn[self.activation_fn](t))
        bias = einops.rearrange(bias, "b t-> b t 1")

        first = Conv1DBlock(self.out_channels, self.kernel_size)(x)
        out = Conv1DBlock(self.out_channels, self.kernel_size)(first + bias)

        needs_projection = x.shape[1] != self.out_channels
        skip = (
            hk.Conv1D(self.out_channels, 1, data_format="NCW")(x)
            if needs_projection
            else x
        )
        return out + skip


class UNetDiffusion(hk.Module):
    """Temporal 1-D U-Net denoiser for trajectory-shaped samples.

    Channel widths are ``dims = [out_dim, emb_dim * m for m in dim_mults]``.
    Consecutive pairs of ``dims`` define one resolution level each.

    - Down path: two ``ResidualTemporalBlock``s per level (plus optional
      attention), with ``DownSample1D`` between levels.
    - Up path: mirrors the down path and consumes the stored skip activations.
    - Conditioning: every ``ResidualTemporalBlock`` receives the concatenation
      of all condition embeddings (one per ``batch_keys`` entry) and the
      timestep embedding.

    Args:
        horizon: trajectory length; ``noise_dim`` is ``(horizon, out_dim)``.
        emb_dim: base channel width and width of each condition embedding.
        out_dim: number of output features per timestep.
        dim_mults: channel multipliers per resolution, e.g. ``(1, 2, 4, 8)``.
        attention: whether to add ``Residual(PreNorm(LinearAttention))`` blocks.
        batch_keys: keys of ``batch_dict`` used as conditions.
        activation_fn: key into ``act2fn``.

    NOTE(review): the temporal length presumably has to be divisible by
    ``2 ** (len(dim_mults) - 1)`` so the skip concatenations line up. Confirm
    against ``DownSample1D`` / ``UpSample1D``.
    """

    def __init__(
        self,
        horizon: int,
        emb_dim: int,
        out_dim: int,
        dim_mults: Tuple[int],  # net_arch, ex) (1,2,4,8)
        attention: bool,
        batch_keys: List[str],
        activation_fn: str = "mish",
    ):
        super().__init__()
        self.horizon = horizon
        self.emb_dim = emb_dim
        self.out_dim = out_dim
        self.noise_dim = (horizon, out_dim)
        self.dim_mults = dim_mults
        self.attention = attention
        self.batch_keys = batch_keys
        self.activation_fn = activation_fn

        self.dims = [out_dim, *map(lambda m: emb_dim * m, dim_mults)]
        # (dim_in, dim_out) channel pair for each resolution level
        self.in_out = list(zip(self.dims[:-1], self.dims[1:]))
        self.num_resolutions = len(self.in_out)

    def __call__(self, x_t: jax.Array, batch_dict: Dict[str, jax.Array], t: jax.Array):
        """Predict the denoising target for a batch of trajectories.

        x_t : [batch_size, horizon, dim]
        t   : [batch_size, 1]
        returns : [batch_size, horizon, out_dim]
        """
        # [batch, horizon, dim] -> [batch, dim, horizon] (channels-first, NCW)
        x_t = einops.rearrange(x_t, "b h ts -> b ts h")

        # condition embedding: one 2-layer MLP per batch key, plus the timestep
        con = []
        for key in self.batch_keys:
            c = hk.Linear(
                self.emb_dim * 2, name=f"mlp_{key}_emb_0", **init_he_normal()
            )(batch_dict[key])
            c = act2fn[self.activation_fn](c)
            c = hk.Linear(self.emb_dim, name=f"mlp_{key}_emb_1", **init_he_normal())(c)
            con.append(c)
        con.append(TimeEmb(self.emb_dim, self.activation_fn)(t))
        con = jnp.concatenate(con, axis=-1)

        # NOTE(review): `LinearAttention` is not imported at the top of this
        # module. With attention=True this only works if it is defined
        # elsewhere in the module; verify before enabling attention.

        # down path: keep each level's output as a skip connection and
        # downsample between levels (not after the last one)
        skips = []
        for ind, (_, dim_out) in enumerate(self.in_out):
            x_t = ResidualTemporalBlock(dim_out, activation_fn=self.activation_fn)(
                x_t, con
            )
            x_t = ResidualTemporalBlock(dim_out, activation_fn=self.activation_fn)(
                x_t, con
            )
            if self.attention:
                x_t = Residual(PreNorm(dim_out, LinearAttention(dim_out)))(x_t)
            skips.append(x_t)
            if ind < self.num_resolutions - 1:
                x_t = DownSample1D(dim_out)(x_t)

        # middle block
        mid_dim = self.dims[-1]
        x_t = ResidualTemporalBlock(mid_dim, activation_fn=self.activation_fn)(x_t, con)
        if self.attention:
            x_t = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))(x_t)
        x_t = ResidualTemporalBlock(mid_dim, activation_fn=self.activation_fn)(x_t, con)

        # up path: it has one level fewer than the down path, so every level
        # upsamples (matching the num_resolutions - 1 downsamples above). The
        # skip stored by the first down level is never consumed.
        for dim_in, _ in reversed(self.in_out[1:]):
            x_t = jnp.concatenate((x_t, skips.pop()), axis=1)
            x_t = ResidualTemporalBlock(dim_in, activation_fn=self.activation_fn)(
                x_t, con
            )
            x_t = ResidualTemporalBlock(dim_in, activation_fn=self.activation_fn)(
                x_t, con
            )
            if self.attention:
                x_t = Residual(PreNorm(dim_in, LinearAttention(dim_in)))(x_t)
            x_t = UpSample1D(dim_in)(x_t)

        x_t = Conv1DBlock(self.emb_dim, kernel_size=5)(x_t)
        x_t = hk.Conv1D(self.out_dim, 1, data_format="NCW", name="last")(x_t)

        # back to [batch, horizon, out_dim]
        out = einops.rearrange(x_t, "b ts h -> b h ts")
        return out


# =================================== Diffusion Module ================================ #


class Diffusion(hk.Module):
    """Denoiser wrapper with DDPM / DDIM reverse sampling and optional
    classifier-free guidance.

    Args:
        diffusion: denoiser module called as ``diffusion(x_t, batch_dict, t)``.
            It must expose ``noise_dim`` (an int or a tuple) and ``batch_keys``.
        n_denoise: number of reverse steps. Schedule coefficients are indexed
            with the step ``i`` for ``i = n_denoise, ..., 1``.
        ddpm_dict: precomputed schedule coefficients (``DDPMCoefficients``).
        guidance_weight: classifier-free guidance weight ``w``. Predictions are
            combined as ``w * eps_cond + (1 - w) * eps_null``. ``1.0`` (the
            default) disables guidance and skips the extra network call.
        predict_epsilon: if True the network output is treated as the noise
            ``eps``; otherwise as an estimate of ``x_0``.
        denoise_type: ``"ddpm"`` (ancestral update, adds sampled noise) or
            ``"ddim"`` (deterministic update; sampled noise is ignored).
    """

    # Reverse-update rules implemented by ``__call__``.
    _DENOISE_TYPES = ("ddpm", "ddim")

    def __init__(
        self,
        diffusion: hk.Module,
        n_denoise: int,
        ddpm_dict: DDPMCoefficients,
        guidance_weight: float = 1.0,  # default: no classifier-free guidance
        predict_epsilon: bool = False,
        denoise_type: str = "ddpm",
    ):
        super().__init__()
        self.diffusion = diffusion
        self.denoise_type = denoise_type
        self.n_denoise = n_denoise
        # normalise to a tuple so it can be appended to the batch dimension
        self.noise_dim = (
            self.diffusion.noise_dim
            if isinstance(self.diffusion.noise_dim, tuple)
            else (self.diffusion.noise_dim,)
        )

        self.guidance_weight = guidance_weight
        self.predict_epsilon = predict_epsilon

        # scheduler params
        self.beta_t = ddpm_dict.beta_t
        self.alpha_t = ddpm_dict.alpha_t
        self.oneover_sqrta = ddpm_dict.oneover_sqrta
        self.sqrt_beta_t = ddpm_dict.sqrt_beta_t
        self.alpha_bar_t = ddpm_dict.alpha_bar_t
        self.alpha_bar_prev_t = ddpm_dict.alpha_bar_prev_t
        self.sqrtab = ddpm_dict.sqrtab
        self.sqrtmab = ddpm_dict.sqrtmab
        self.ma_over_sqrtmab_inv = ddpm_dict.ma_over_sqrtmab_inv
        self.posterior_log_beta = ddpm_dict.posterior_log_beta
        self.posterior_mean_coef1 = ddpm_dict.posterior_mean_coef1
        self.posterior_mean_coef2 = ddpm_dict.posterior_mean_coef2

    def __call__(
        self,
        x_t: jax.Array,
        batch_dict: Dict[str, jax.Array],
        t: jax.Array,
        denoise: bool = False,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[int, Tuple[jax.Array, Optional[jax.Array]]]]:
        """Predict the denoising target, or run the full reverse chain.

        With ``denoise=False``, returns ``(self._predict_eps(x_t, batch_dict, t), {})``.

        With ``denoise=True``, ``x_t`` and ``t`` are ignored. The chain starts
        from ``x_T ~ N(0, I)`` of shape ``(batch_size,) + noise_dim`` (batch
        size taken from ``batch_dict["obs"]``) and runs ``n_denoise`` reverse
        steps. ``deterministic=True`` drops the added noise in the DDPM update.
        Returns ``(x_0, trace)`` where:

        - ``trace[n_denoise] = (x_T, None)``;
        - ``trace[i - 1] = (x_{i-1}, eps_i)`` for each step ``i``. ``eps_i`` is
          the raw network output, except for DDIM with
          ``predict_epsilon=False``, where it is the noise implied by the
          ``x_0`` estimate.

        Raises:
            ValueError: if ``denoise=True`` and ``denoise_type`` is not
                ``"ddpm"`` or ``"ddim"``. Previously such a value silently
                returned the initial noise.
        """
        if not denoise:
            return self._predict_eps(x_t, batch_dict, t, denoise=False), {}

        if self.denoise_type not in self._DENOISE_TYPES:
            raise ValueError(
                f"Unknown denoise_type {self.denoise_type!r}; "
                f"expected one of {self._DENOISE_TYPES}"
            )

        batch_size = batch_dict["obs"].shape[0]
        sample_shape = (batch_size,) + self.noise_dim
        # sample initial noise, x_T ~ N(0, 1)
        x_i = jax.random.normal(hk.next_rng_key(), shape=sample_shape)
        x_i_trace = {self.n_denoise: (x_i, None)}

        for i in range(self.n_denoise, 0, -1):
            t_i = jnp.repeat(jnp.array([[i]]), batch_size, axis=0)
            # No noise on the last step or when deterministic. A key is drawn
            # even for DDIM (which ignores it) so the RNG sequence is the same
            # for both denoise types.
            noise = (
                jax.random.normal(hk.next_rng_key(), shape=sample_shape)
                if (i > 1 and not deterministic)
                else 0
            )

            # raw network output: noise if predict_epsilon, else an x_0 estimate
            eps = self._predict_eps(x_i, batch_dict, t_i, denoise=True)

            if self.denoise_type == "ddpm":
                if self.predict_epsilon:
                    # ancestral step from the noise prediction
                    x_i = (
                        self.oneover_sqrta[i]
                        * (x_i - self.ma_over_sqrtmab_inv[i] * eps)
                        + self.sqrt_beta_t[i] * noise
                    )
                else:
                    # posterior mean from the x_0 estimate, plus scaled noise
                    x_i = (
                        self.posterior_mean_coef1[i] * eps
                        + self.posterior_mean_coef2[i] * x_i
                        + jnp.exp(0.5 * self.posterior_log_beta[i]) * noise
                    )
            else:  # "ddim" (validated above): deterministic update
                if self.predict_epsilon:
                    pred_x_0 = (x_i - self.sqrtmab[i] * eps) / self.sqrtab[i]
                else:
                    pred_x_0 = eps
                    eps = (x_i - self.sqrtab[i] * pred_x_0) / self.sqrtmab[i]
                x_i = (
                    jnp.sqrt(self.alpha_bar_prev_t[i]) * pred_x_0
                    + jnp.sqrt(1.0 - self.alpha_bar_prev_t[i]) * eps
                )

            x_i_trace[i - 1] = (x_i, eps)
        return x_i, x_i_trace

    def _predict_eps(
        self,
        x_t: jax.Array,
        batch_dict: Dict[str, jax.Array],
        t: jax.Array,
        denoise: bool = False,
    ) -> jax.Array:
        """Run the denoiser, applying classifier-free guidance if enabled.

        With ``guidance_weight == 1.0`` this is a single conditional call.
        Otherwise the network is also evaluated with every ``batch_keys`` entry
        zeroed and with ``t`` zeroed. The two predictions are mixed as
        ``w * eps + (1 - w) * eps_null``.

        NOTE(review): ``denoise`` is currently unused. Guidance is therefore
        also applied on the training path (``denoise=False``) whenever
        ``guidance_weight != 1.0``; confirm that is intended.

        NOTE(review): the unconditional branch uses ``t = 0`` rather than the
        current timestep, whereas standard classifier-free guidance evaluates
        both branches at the same ``t``. Confirm this matches how the null
        model is trained before changing it.
        """
        eps = self.diffusion(x_t, batch_dict, t)
        if self.guidance_weight == 1.0:
            return eps
        batch_null = {
            key: jnp.zeros_like(batch_dict[key]) for key in self.diffusion.batch_keys
        }
        eps_null = self.diffusion(x_t, batch_null, jnp.zeros_like(t))
        return self.guidance_weight * eps + (1 - self.guidance_weight) * eps_null
